import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
import matplotlib.ticker as ticker
import os

# Create a figure with two side-by-side subplots, one per dataset. The y-axes
# are intentionally NOT shared (sharey=False): each panel configures its own
# y tick locator/formatter below.
fig, axs = plt.subplots(1, 2, figsize=(20, 7), sharey=False)

# Line objects for the figure-level legend, at most one per method label.
legend_handles = []

# Method configurations (identical for every dataset, so defined once).
methods = ['SiPBA', 'PZOBO', 'AID-CG', 'AID-FP']
colors = ['#d62728', '#1f77b4', '#ff7f0e', '#2ca02c']
line_styles = ['-', '-', '-', '-']  # All solid lines
max_time = 300  # Time cutoff (seconds) applied to the loaded curves
main_xmax = 200  # Right edge of the visible x-range on the main plots

# Accuracy level whose first-crossing time is reported for each dataset.
report_thresholds = {'MNIST': 0.90, 'FashionMNIST': 0.75}

# Inset zoom y-range per dataset. A dict lookup fails loudly (KeyError) for an
# unknown dataset instead of silently reusing the previous panel's window.
inset_ylims = {
    'FashionMNIST': (0.7, 0.95),  # High-accuracy region
    'MNIST': (0.9, 1.0),          # Near-perfect region
}

# Process both datasets
for idx, dataset in enumerate(['MNIST', 'FashionMNIST']):
    ax = axs[idx]  # Current subplot

    # Create inset axis for zoomed view
    axins = inset_axes(ax, width="60%", height="60%", loc='center',
                       bbox_to_anchor=(0.2, 0.1, 0.6, 0.6),
                       bbox_transform=ax.transAxes)

    # Plot each method
    for method, color, ls in zip(methods, colors, line_styles):
        try:
            # Load precomputed results.
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
            # load result files produced by trusted code.
            data = np.load(f"result/{dataset}_averaged_results_{method}.npy",
                           allow_pickle=True).item()
        except FileNotFoundError:
            print(f"File not found: {method}")
            continue

        all_times = data["mean_time"]
        all_accs = data["mean_acc"]

        # Align all runs to the shortest run so they can be stacked; only the
        # first run's timestamps are used as the shared x-axis.
        min_length = min(len(t) for t in all_times)
        ref_times = all_times[0][:min_length]
        aligned_accs = np.array([accs[:min_length] for accs in all_accs])

        # Truncate at the first timestamp strictly greater than max_time.
        truncate_idx = next(
            (i for i, t in enumerate(ref_times) if t > max_time),
            len(ref_times),
        )
        if truncate_idx == 0:
            # Nothing to plot (first sample is already past the cutoff).
            print(f"[{dataset}] {method}: no samples within {max_time}s, skipped")
            continue
        truncated_times = ref_times[:truncate_idx]
        truncated_accs = np.mean(aligned_accs, axis=0)[:truncate_idx]
        std_acc = np.std(aligned_accs, axis=0)[:truncate_idx]

        # Report the first time the mean accuracy reaches the dataset's
        # threshold. Print the threshold itself (not the final accuracy),
        # since that is the level actually reached at time t.
        threshold = report_thresholds.get(dataset)
        if threshold is not None:
            hit_time = next(
                (t for t, acc in zip(truncated_times, truncated_accs)
                 if acc >= threshold),
                None,
            )
            if hit_time is not None:
                print(f"[{dataset}] {method} reaches {threshold*100:.2f}% "
                      f"at time {hit_time:.2f}s")

        # Plot main curve and keep its line object
        line = ax.plot(truncated_times, truncated_accs, label=method,
                       color=color, linestyle=ls, linewidth=5)[0]

        # Keep one legend handle per method, from whichever panel plots it
        # first, so a method missing from one dataset still gets an entry.
        if method not in {h.get_label() for h in legend_handles}:
            legend_handles.append(line)

        # Add standard deviation shading
        ax.fill_between(truncated_times,
                        truncated_accs - std_acc,
                        truncated_accs + std_acc,
                        color=color, alpha=0.15)

        # Plot same data in inset
        axins.plot(truncated_times, truncated_accs, color=color,
                   linestyle=ls, linewidth=5)
        axins.fill_between(truncated_times,
                           truncated_accs - std_acc,
                           truncated_accs + std_acc,
                           color=color, alpha=0.15)

    # Configure main plot
    ax.set_xlabel('Time (s)', fontsize=25)
    if idx == 0:
        # Only the left plot gets a y-axis label
        ax.set_ylabel('Test accuracy', fontsize=25)
    ax.set_title(f'{dataset}', fontsize=25)
    ax.tick_params(axis='both', which='major', labelsize=25, width=2)
    ax.grid(True, linestyle=':', alpha=0.7)
    ax.set_xlim(0, main_xmax)
    ax.set_ylim(0, 1.05)
    ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=6))
    if idx == 0:
        # Left plot: ticks every 0.2, hiding any label above 1.0
        ax.yaxis.set_major_locator(ticker.MultipleLocator(0.2))
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(
            lambda x, _: '{:.1f}'.format(x) if x <= 1.01 else ''))
    else:
        ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=6))

    # Set zoom area for inset: last 50 seconds of the visible window
    x1, x2 = main_xmax - 50, main_xmax
    y1, y2 = inset_ylims[dataset]
    axins.set_xlim(x1, x2)
    axins.set_ylim(y1, y2)

    # Configure inset ticks
    axins.xaxis.set_major_locator(ticker.MaxNLocator(nbins=3))
    axins.yaxis.set_major_locator(ticker.MaxNLocator(nbins=3))
    axins.yaxis.set_major_formatter(ticker.FuncFormatter(
        lambda x, _: '{:.2f}'.format(x) if x <= 1.0 else ''))
    axins.tick_params(axis='both', labelsize=25)
    axins.grid(True, linestyle=':', alpha=0.7)

    # Add connector lines between main plot and inset
    mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
# Reserve the top 18% of the figure for the shared legend and narrow the gap
# between the two panels. Spacing is set manually rather than via
# tight_layout.
fig.subplots_adjust(top=0.82, wspace=0.15)

# Create a unified legend from the collected line objects
fig.legend(handles=legend_handles, loc='upper center',
           fontsize=25, frameon=True, shadow=True,
           bbox_to_anchor=(0.5, 1.0), ncol=4)

# Save a high-resolution image. exist_ok avoids the check-then-create race of
# a separate os.path.exists test.
pic_dir = "./pic"
os.makedirs(pic_dir, exist_ok=True)
fig.savefig(os.path.join(pic_dir, 'combined_mnist_fashionmnist.png'),
            dpi=300, bbox_inches='tight')

plt.show()